import gymnasium
import argparse
import numpy as np
from scipy.ndimage import convolve
from tensorboardX import SummaryWriter
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
os.environ["MUJOCO_GL"] ="egl" #"egl"

import matplotlib.pyplot as plt
import torchvision.utils as vutils
from matplotlib.backends.backend_pdf import PdfPages
import torch
from dm_control import suite
import envs.wrappers as wrappers
import envs.dmc as dmc

def sample_episode(env, writer, *, start=20, count=6, single_index=10,
                   pdf_path="gridall_three.pdf"):
    """Roll out one random-action episode and log its frames and motion masks.

    Observations returned by ``env.step`` are scaled by 1/255 and collected
    until the episode ends (``terminated or truncated``). The episode is then
    written to ``writer``:

    * the whole episode as a 30 fps video;
    * image grids of ``count`` consecutive frames starting at original frame
      ``start``, for the raw frames, the undilated masks (``dynamic_mask``)
      and the dilated masks (``extract_dynamic_mask``);
    * a combined three-row grid, also saved as a one-page PDF at ``pdf_path``;
    * original frame ``single_index`` and its dilated mask as single images.

    Mask frames are aligned to the original frame they were computed for, so
    every column of the grids shows the same timestep.

    Args:
        env: environment whose ``reset()`` is called once and whose
            ``step(action)`` returns ``(obs, reward, terminated, truncated,
            info)``. ``obs`` is assumed to be an HxWxC image array in the
            0-255 range — TODO confirm against ``envs.dmc``.
        writer: tensorboardX ``SummaryWriter``. It is closed before returning,
            even if logging fails.
        start: first original frame shown in the grids. Must be >= 1, because
            motion masks only exist from frame 1 on.
        count: number of consecutive frames per grid row (>= 1).
        single_index: original frame logged as a single image (>= 1).
        pdf_path: output path for the combined grid PDF.

    Raises:
        ValueError: if ``start``, ``count`` or ``single_index`` is < 1.
    """
    if start < 1 or single_index < 1 or count < 1:
        raise ValueError(
            "start and single_index must be >= 1 (motion masks begin at frame 1) "
            "and count must be >= 1"
        )
    try:
        frames = _rollout_random(env)
        _log_episode(writer, frames, start=start, count=count,
                     single_index=single_index, pdf_path=pdf_path)
    finally:
        writer.close()


def _rollout_random(env):
    """Step ``env`` with sampled actions until done; return stacked frames / 255.

    Only post-step observations are collected. The observation returned by
    ``reset()`` is discarded, matching the original behavior.
    """
    env.reset()
    frames = []
    done = False
    while not done:
        obs, _reward, terminated, truncated, _info = env.step(env.action_space.sample())
        frames.append(obs / 255.0)
        done = terminated or truncated
    return np.stack(frames)


def _log_episode(writer, frames, *, start, count, single_index, pdf_path):
    """Write the video, image grids, combined PDF and single frames for ``frames``.

    ``frames`` is the (T, H, W, C) array produced by ``_rollout_random``.
    """
    episode = torch.from_numpy(frames).unsqueeze(0)      # (1, T, H, W, C)
    video = episode.permute(0, 1, 4, 2, 3)               # (1, T, C, H, W)
    writer.add_video('Boxing_diff_video', video, fps=30)

    # Both mask functions build their output from x[:, 1:], so mask frame k
    # belongs to original frame k + 1. Shift the mask window back by one so
    # that each grid column compares the same timestep.
    originals = video.squeeze(0)
    dilated = extract_dynamic_mask(episode).permute(0, 1, 4, 2, 3).squeeze(0)
    undilated = dynamic_mask(episode).permute(0, 1, 4, 2, 3).squeeze(0)
    frame_window = slice(start, start + count)
    mask_window = slice(start - 1, start - 1 + count)

    dilated_row = dilated[mask_window]
    writer.add_image('Boxing_diff_figure', vutils.make_grid(dilated_row, nrow=count))

    undilated_row = undilated[mask_window]
    writer.add_image('Boxing_single_diff_figure', vutils.make_grid(undilated_row, nrow=count))

    original_row = originals[frame_window]
    writer.add_image('Boxing_figure', vutils.make_grid(original_row, nrow=count))

    # Rows, top to bottom: original frames, undilated masks, dilated masks.
    combined = vutils.make_grid(
        torch.cat([original_row, undilated_row, dilated_row], 0), nrow=count
    )
    writer.add_image('all_figure', combined)
    _save_grid_pdf(combined, pdf_path)

    writer.add_image('Boxing_diff_one_figure', dilated[single_index - 1])
    writer.add_image('Boxing_one_figure', originals[single_index])


def _save_grid_pdf(grid, pdf_path):
    """Save a (C, H, W) image grid as a single axis-free page in ``pdf_path``."""
    with PdfPages(pdf_path) as pdf:
        fig = plt.figure(figsize=(8, 8))
        try:
            plt.axis("off")
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
            pdf.savefig(fig, bbox_inches="tight", pad_inches=0)
        finally:
            # Release the figure even if rendering fails.
            plt.close(fig)
 
 

def extract_dynamic_mask(x, *, error=1e-3, dilation_size=3, dilation_rate=1):
    """Keep only pixels near frame-to-frame changes, using a dilated change mask.

    For each t >= 1, an element is "dynamic" when
    ``|x[:, t] - x[:, t - 1]| > error``. That binary mask is dilated per
    channel with a ``dilation_size`` x ``dilation_size`` box whose taps are
    ``dilation_rate`` pixels apart. The result is multiplied into ``x[:, t]``.

    Args:
        x: 5-D tensor in (B, T, H, W, C) order. H and W are the dilated
            spatial axes and each channel is dilated independently.
        error: change threshold; a difference must strictly exceed it.
        dilation_size: odd side length of the dilation window.
        dilation_rate: spacing between window taps (1 means a contiguous box).

    Returns:
        Tensor of shape (B, T - 1, H, W, C): ``x[:, 1:]`` with elements
        outside the dilated change mask set to zero. Output frame k
        corresponds to input frame k + 1.

    Raises:
        AssertionError: if ``x`` is not 5-D.
        ValueError: if ``dilation_size`` is even, which would change H and W.
    """
    assert x.dim() == 5
    if dilation_size % 2 == 0:
        raise ValueError("dilation_size must be odd to preserve the spatial size")

    changed = (torch.abs(x[:, 1:] - x[:, :-1]) > error).float()

    B, T, H, W, C = changed.shape
    # Fold batch, time and channel into the conv batch dimension so every
    # channel is dilated independently by a single-channel convolution.
    flat = changed.permute(0, 1, 4, 2, 3).reshape(B * T * C, 1, H, W)
    # Build the kernel on the mask's device so CUDA inputs also work.
    kernel = torch.ones((1, 1, dilation_size, dilation_size),
                        dtype=flat.dtype, device=flat.device)
    # Padding rate * (size // 2) keeps H and W unchanged for any dilation rate.
    dilated = F.conv2d(flat, kernel,
                       padding=dilation_rate * (dilation_size // 2),
                       dilation=dilation_rate)
    dilated = (dilated > 0).float()
    dilated = dilated.reshape(B, T, C, H, W).permute(0, 1, 3, 4, 2)

    return dilated * x[:, 1:]

def dynamic_mask(x, *, error=1e-3):
    """Keep only elements that changed since the previous frame (no dilation).

    Args:
        x: 5-D tensor whose dim 1 is time. All other axes are compared
            elementwise; callers in this file pass (B, T, H, W, C).
        error: change threshold; ``|x[:, t] - x[:, t - 1]|`` must strictly
            exceed it.

    Returns:
        Tensor shaped like ``x[:, 1:]``: frame t (t >= 1) of ``x`` with every
        element that did not change since frame t - 1 set to zero. Output
        frame k therefore corresponds to input frame k + 1.

    Raises:
        AssertionError: if ``x`` is not 5-D.
    """
    assert x.dim() == 5
    changed = (torch.abs(x[:, 1:] - x[:, :-1]) > error).float()
    return changed * x[:, 1:]


if __name__ == "__main__":
    # Command-line options. The defaults reproduce the previously hard-coded run.
    parser = argparse.ArgumentParser(
        description="Roll out one random-action episode and log frames and motion masks."
    )
    parser.add_argument("--task", default="finger_spin",
                        help="Task name passed to dmc.DeepMindControl.")
    parser.add_argument("--seed", type=int, default=0,
                        help="Seed passed to dmc.DeepMindControl.")
    parser.add_argument("--log-dir", default="",
                        help="Directory passed to SummaryWriter(log_dir=...).")
    args = parser.parse_args()

    writer = SummaryWriter(log_dir=args.log_dir)

    env = dmc.DeepMindControl(args.task, 2, (64, 64), seed=args.seed)
    env = wrappers.NormalizeActions(env)

    # sample_episode returns nothing and closes the writer itself.
    sample_episode(env=env, writer=writer)
